import numpy as np
import torch
from utils import get_sentence_token_spans


def get_answer_base(context, question, agents_dict, args):
    """
    Answer a question directly with the plain QA agent.

    Parameters
    ----------
    context : str
        Passage used as grounding for the answer.

    question : str
        Question to be answered.

    agents_dict : dict
        Mapping of agent names to initialized agents; must contain "qa".

    args : argparse.Namespace
        Parsed run configuration (unused here; kept for a uniform
        interface with the other answer strategies).

    Returns
    -------
    str
        Answer produced by the "qa" agent.
    """
    qa_agent = agents_dict["qa"]
    return qa_agent.get_answer(context, question)


def get_answer_cot(context, question, agents_dict, args):
    """
    Answer a question with the chain-of-thought reasoning agent.

    Parameters
    ----------
    context : str
        Passage used as grounding for the answer.

    question : str
        Question to be answered.

    agents_dict : dict
        Mapping of agent names to initialized agents; must contain "cot".

    args : argparse.Namespace
        Parsed run configuration (unused here; kept for a uniform
        interface with the other answer strategies).

    Returns
    -------
    str
        Answer produced by the "cot" agent.
    """
    cot_agent = agents_dict["cot"]
    return cot_agent.get_answer(context, question)


def get_answer_fullelicit(context, question, agents_dict, args):
    """
    Answer a question after marking the whole context as important.

    The entire passage is wrapped in the configured importance markers
    before being handed to the plain QA agent.

    Parameters
    ----------
    context : str
        Passage used as grounding for the answer.

    question : str
        Question to be answered.

    agents_dict : dict
        Mapping of agent names to initialized agents; must contain "qa".

    args : argparse.Namespace
        Parsed run configuration providing `marker_impstart` and
        `marker_impend`, the evidence start/end markers.

    Returns
    -------
    str
        Answer produced by the "qa" agent on the marked context.
    """
    # Wrap the whole passage in importance markers so the agent treats
    # all of it as evidence.
    marked_context = " ".join([args.marker_impstart, context, args.marker_impend])
    return agents_dict["qa"].get_answer(marked_context, question)


def get_answer_promptelicit(
    context, question, agents_dict, args, return_evidence=False
):
    """
    Generate an answer by eliciting evidence from a prompting-based agent.

    The "pe" agent is prompted to quote evidence sentences from the
    context. Each quoted sentence that literally occurs in the context is
    wrapped in the configured importance markers, and the "se" agent then
    answers the question over the marked-up context.

    Parameters
    ----------
    context : str
        The context passage to use for answering the question.

    question : str
        The question to answer.

    agents_dict : dict
        Dictionary containing initialized agents; must contain "pe"
        (prompt elicitation) and "se" (answering over marked context).

    args : argparse.Namespace
        Parsed arguments providing `marker_impstart`, `marker_impend`,
        and `max_ans_tokens`.

    return_evidence : bool, optional (default=False)
        If True, return the selected evidence sentences along with the answer.

    Returns
    -------
    str or tuple
        The answer generated by the "se" agent. If `return_evidence` is
        True, a `(answer, evidence_sents)` tuple is returned instead.
    """

    # Internal function to extract evidence sentences using the "pe" agent.
    def _prompt_elicit(
        agent_elicit,
        context,
        question,
        marker_impstart,
        marker_impend,
        max_gen_tokens,
    ):
        """
        Perform prompt-based evidence elicitation.

        Returns
        -------
        elicited_context : str
            Context with evidence sentences wrapped in markers.

        evidence_sents : list of str
            Evidence sentences that were found verbatim in the context.
        """
        model_ans_raw = agent_elicit.get_answer(
            context, question, max_ans_tokens=max_gen_tokens
        )
        elicited_context = context
        evidence_sents = []

        # Strip list bullets and surrounding quotes from each generated line.
        candidates = [
            line.lstrip("- ").lstrip('"').rstrip('"')
            for line in model_ans_raw.split("\n")
        ]
        for sent in candidates:
            if not sent:
                # Skip blank lines: str.find("") returns 0, which would
                # otherwise insert empty markers at the start of the context
                # and record "" as evidence.
                continue
            # BUGFIX: search in the *current* elicited context, not the
            # original one — after the first marker insertion, positions
            # computed on the original context are stale and slicing
            # `elicited_context` with them corrupts the text.
            sent_start = elicited_context.find(sent)
            if sent_start > -1:
                sent_end = sent_start + len(sent)
                # Insert evidence markers around the identified sentence.
                elicited_context = (
                    elicited_context[:sent_start]
                    + f"{marker_impstart} {sent} {marker_impend}"
                    + elicited_context[sent_end:]
                )
                evidence_sents.append(sent)
        return elicited_context, evidence_sents

    # Perform evidence elicitation and highlight key sentences.
    elicited_context, evidence_sents = _prompt_elicit(
        agents_dict["pe"],
        context,
        question,
        args.marker_impstart,
        args.marker_impend,
        args.max_ans_tokens,
    )
    # Use the "se" agent to generate the final answer based on the
    # highlighted context.
    model_ans = agents_dict["se"].get_answer(elicited_context, question)

    # Return the answer and optionally the evidence sentences.
    if return_evidence:
        return model_ans, evidence_sents
    return model_ans


def get_answer_selfelicit(
    context, question, agents_dict, device, args, return_evidence=False
):
    """
    Generate an answer by self-eliciting evidence from model attention patterns.

    The "qa" agent's model is run once with attention outputs enabled; the
    last-token attention over the context span is averaged across a range
    of layers, aggregated to sentence level, and sentences scoring above a
    relative threshold are wrapped in importance markers. The "se" agent
    then answers over the marked context.

    Parameters
    ----------
    context : str
        The context passage to use for answering the question.

    question : str
        The question to answer.

    agents_dict : dict
        Dictionary containing initialized agents; must contain "qa"
        (provides model/tokenizer/token-span helpers) and "se".

    device : torch.device
        Device on which the model computations are performed.

    args : argparse.Namespace
        Parsed arguments providing `marker_impstart`, `marker_impend`,
        `layer_span` (pair of fractions of the layer count), and `alpha`
        (relative score threshold).

    return_evidence : bool, optional (default=False)
        If True, return the selected evidence sentences along with the answer.

    Returns
    -------
    str or tuple
        The answer generated by the "se" agent. If `return_evidence` is
        True, a `(answer, evidence_sents)` tuple is returned instead.
    """

    # Nested function for the attention-based evidence selection logic.
    def self_elicit(
        output_att,
        sents,
        sent_spans,
        context_span,
        marker_impstart,
        marker_impend,
        layer_span,
        threshold,
        verbose=False,
    ):
        """
        Select evidence sentences from attention scores.

        Parameters
        ----------
        output_att : list of torch.Tensor
            Per-layer attention tensors from the model.

        sents : list of str
            Sentences of the context.

        sent_spans : list of tuple
            Token span (start, end) of each sentence within the context span.

        context_span : tuple
            Token span of the whole context in the input sequence.

        marker_impstart, marker_impend : str
            Evidence start/end markers.

        layer_span : tuple of int
            Half-open range of layer indices to average over.

        threshold : float
            Relative threshold: sentences scoring >= max_score * threshold
            are candidates.

        verbose : bool, optional
            If True, print the sentence scores and selected indices.

        Returns
        -------
        elicited_context : str
            Context rebuilt with evidence sentences marked.

        evidence_sents : list of str
            Selected evidence sentences.

        evidence_spans : list of tuple
            Token spans of the selected evidence sentences.
        """
        # Last-token attention over the context tokens, head-averaged,
        # for each layer in the requested range.
        att_layer_scores = np.array(
            [
                output_att[l][0, :, -1, context_span[0] : context_span[1]]
                .detach()
                .cpu()
                .float()
                .numpy()
                .mean(axis=0)
                for l in range(layer_span[0], layer_span[1])
            ]
        )
        # Normalize each layer's scores so layers contribute equally.
        att_layer_scores /= att_layer_scores.sum(axis=1, keepdims=True)

        # Aggregate token-level scores into sentence-level scores.
        att_token_scores = att_layer_scores.mean(axis=0)
        sent_scores = np.array(
            [
                att_token_scores[sent_span[0] : sent_span[1]].mean()
                for sent_span in sent_spans
            ]
        )
        # Sentences within `threshold` of the best score are candidates.
        target_sent_index = (sent_scores >= sent_scores.max() * threshold).nonzero()[0]

        if verbose:
            print(f"Sentences scores: {sent_scores.round(2)}")
            print(f"Target sentence index: {target_sent_index}")

        # BUGFIX: apply the minimum-length filter once, so evidence_sents,
        # the marked context, and evidence_spans stay consistent. Previously
        # evidence_spans included spans of too-short sentences that were
        # never marked nor reported in evidence_sents.
        selected = [
            i for i in target_sent_index if len(sents[i].replace(" ", "")) > 5
        ]
        selected_set = set(selected)

        elicited_context = ""
        sent_end = "\n"
        evidence_sents = []
        for i, sent in enumerate(sents):
            if i in selected_set:
                # Wrap selected sentences in evidence markers.
                elicited_context += (
                    f"{marker_impstart} {sent} {marker_impend} {sent_end}"
                )
                evidence_sents.append(sent)
            else:
                elicited_context += f"{sent} {sent_end}"

        # Token spans of the selected evidence sentences.
        evidence_spans = [sent_spans[i] for i in selected]

        return elicited_context, evidence_sents, evidence_spans

    # Prepare input tokens and locate the context inside the prompt.
    input_ids = (
        agents_dict["qa"]
        .get_chat_template_input_ids(context, question, return_tensors="pt")
        .to(device)
    )
    context_span = agents_dict["qa"].get_context_token_span(context, question)
    context_ids = input_ids[:, context_span[0] : context_span[1]]
    # Split the context into sentences with their token spans.
    sent_spans, sents = get_sentence_token_spans(
        context_ids, agents_dict["qa"].tokenizer
    )
    # Run the model once to obtain attention outputs.
    outputs = agents_dict["qa"].model(
        input_ids,
        output_attentions=True,
        attention_mask=torch.ones_like(input_ids),
    )
    output_att = outputs.attentions
    n_layers = len(output_att)
    # Convert the fractional layer range into concrete layer indices.
    layer_span = (
        int(args.layer_span[0] * n_layers),
        int(args.layer_span[1] * n_layers),
    )
    # Perform evidence elicitation using the computed attention patterns.
    elicited_context, evidence_sents, evidence_spans = self_elicit(
        output_att,
        sents,
        sent_spans,
        context_span,
        args.marker_impstart,
        args.marker_impend,
        layer_span=layer_span,
        threshold=args.alpha,
    )
    # Free GPU memory held by the attention tensors before answering.
    del outputs
    torch.cuda.empty_cache()
    # Use the "se" agent to generate the final answer on the marked context.
    model_ans = agents_dict["se"].get_answer(elicited_context, question)
    # Return the answer and optionally the evidence sentences.
    if return_evidence:
        return model_ans, evidence_sents
    else:
        return model_ans
